import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
from multiagent.scenarios.hunter_invader_para import ScenePara

# if ScenePara.visualize:
#     import cv2
#     cv2.namedWindow('hunter invader')
    

def Norm(x):
    """Return the Euclidean (L2) norm of vector x."""
    length = np.linalg.norm(x)
    return length

def Unit(m):
    """Convert a distance of m meters into world units (1 m == 0.05 world units)."""
    return 0.05 * m
def ToMeter(x):
    """Convert a distance in world units back to meters (inverse of Unit)."""
    meters = x / 0.05
    return meters

def assert_and_break(cond):
    """Soft assertion: print a warning instead of raising when cond is falsy.

    Unlike a real assert this never interrupts execution, so callers keep
    running after a failed check.
    """
    if not cond:
        print("fail!")

class Scenario(BaseScenario):
    """Hunter-vs-invader pursuit scenario.

    Hunters (good agents) defend landmark "nests" against invader agents
    (the adversaries).  Nearly every numeric setting is read from ScenePara
    rather than from the constructor arguments.  The main process
    (process_index == 0) additionally streams positions to a MATLAB plot
    server over a UDP socket.
    """
    def __init__(self, num_agents=3, dist_threshold=0.1, arena_size=1, identity_size=0, process_index=0):
        """Initialize scenario-level state.

        num_agents / dist_threshold / arena_size / identity_size are accepted
        for interface compatibility but are overridden by ScenePara values.
        process_index == 0 marks the main process, which owns the
        visualization client; worker processes (> 0) skip it.
        """
        self.num_agents = ScenePara.num_agent
        self.arena_size = ScenePara.arena_size
        self.caught = 0
        self.rew_other = 0
        self.manual_render = None
        self.discrete_action = ScenePara.discrete_action
        self.cam_range = ScenePara.arena_size * 1.2
        self.process_index = process_index  # process_index == 0 is the main process
        if process_index == 0: 
            # communicate with plot server on MATLAB via UDP socket
            from mCOMv4 import mUDP_client
            try:
                self.uc = mUDP_client(path = "/home/fuqingxu/Desktop/com/visualize/", digit=4, rapid_flush = True)
            except:
                # NOTE(review): bare except — intended as a fallback to a second
                # home directory, but it also masks any unrelated failure
                # (import errors, socket errors, ...).
                self.uc = mUDP_client(path = "/home/qingxu/Desktop/com/visualize/", digit=4, rapid_flush = True)
            self.p0_num_ep = 0  # episode counter, maintained on the main process only
            self.uc.v2d_init()
        # process_index > 0: worker process, no visualization client

    def make_world(self):
        """Create the World (agents + landmarks), reset it, and cache entity lists."""
        world = World()
        # set any world properties first
        world.dim_c = 2
        self.num_good_agents = ScenePara.hunter_num
        self.num_hunters = ScenePara.hunter_num
        self.num_adversaries = ScenePara.invader_num
        self.num_invaders = ScenePara.invader_num
        num_agents = self.num_agents
        num_landmarks = ScenePara.num_landmarks
        self.num_landmarks = num_landmarks
        # add agents: hunters and invaders together
        world.agents = [Agent() for i in range(num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            # the first num_adversaries agents are invaders (the adversaries);
            # the remaining agents are hunters
            agent.IsInvader = True if i < self.num_adversaries else False
            agent.adversary = True if i < self.num_adversaries else False
            agent.size = ScenePara.Invader_Size  if agent.IsInvader else ScenePara.Hunter_Size      # size is a radius
            agent.accel =ScenePara.Invader_Accel if agent.IsInvader else ScenePara.Hunter_Accel
            agent.max_speed = ScenePara.Invader_MaxSpeed if agent.IsInvader else ScenePara.Hunter_MaxSpeed
            agent.indx = i
            agent.caught = False
            #agent.life = 1
            agent.live = 1      # 1 alive; set False when the agent is intercepted
            agent.live_adv = 1
        # add landmarks
        world.landmarks = [Landmark() for i in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = ScenePara.Landmark_Size
            landmark.boundary = False
        # make initial conditions
        self.reset_world(world)
        world.max_steps_episode = ScenePara.max_steps_episode
        # NOTE(review): these list attributes shadow the hunters()/invaders()
        # methods defined below — after this point self.hunters/self.invaders
        # are lists, so benchmark_data()'s self.hunters(world) call would raise
        # TypeError ("'list' object is not callable").
        self.hunters  = [agent for agent in world.agents if not agent.IsInvader]
        self.invaders = [agent for agent in world.agents if agent.IsInvader]
        self.landmarks = world.landmarks
        return world

    def rand(self, low, high):
        """Uniform random scalar in [low, high)."""
        return np.random.rand()*(high - low) + low

    def landmark_spawn_position(self, landmark, world):
        """Place a landmark at the nest center with zero velocity."""
        landmark.state.p_pos = ScenePara.nest_center_pos
        # landmark.state.p_pos = np.random.uniform(-0.1*ScenePara.hunter_spawn_pos_lim,  \
        #     + 0.1*ScenePara.hunter_spawn_pos_lim, world.dim_p) \
        #     + ScenePara.nest_center_pos
        landmark.state.p_vel = np.zeros(world.dim_p)

    def spawn_position(self, agent, world):
        """(Re)spawn an agent, resetting velocity and liveness flags.

        Hunters spawn uniformly inside a square around the nest; invaders
        spawn on a random ring around the nest, rejection-sampled until the
        point also lies inside the arena square.
        """
        if not agent.IsInvader:
            # hunter: uniformly random inside a square centered on the nest
            agent.state.p_pos = np.random.uniform( -ScenePara.hunter_spawn_pos_lim, 
                                                    ScenePara.hunter_spawn_pos_lim, 
                                                    world.dim_p) + ScenePara.nest_center_pos
            # initial velocity is zero
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
            agent.caught = False
            agent.live = 1
            agent.live_adv = 1
            agent.collide = True
            agent.movable = True
        else:
            # invader
            # spawn direction relative to nest
            while True:
                theta = np.random.rand()*2*np.pi - np.pi
                d = self.rand(low=1.0*ScenePara.invader_spawn_limit, high=1.1*ScenePara.invader_spawn_limit)
                agent.state.p_pos = d * np.array([np.cos(theta),np.sin(theta)]) + ScenePara.nest_center_pos
                x = agent.state.p_pos[0]
                y = agent.state.p_pos[1]
                if x<self.arena_size and x>-self.arena_size and y<self.arena_size and y>-self.arena_size:
                    break
                else:
                    pass
                    # print('spawn outside map')

            # initial velocity is zero
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
            agent.caught = False
            agent.live = 1
            agent.live_adv = 1
            agent.collide = True
            agent.movable = True
        # if self.manual_render is not None: 
        #     self.manual_render()
        #     print("hi")
    def invader_revise(self, agent, world):
        """Respawn an intercepted invader (alias for spawn_position)."""
        self.spawn_position(agent,world)

    def reset_world(self, world):
        """Reset colors, positions and episode bookkeeping for a new episode."""
        for i, agent in enumerate(world.agents):
            if agent.IsInvader:
                agent.color = np.array([0.85, 0.35, 0.35]) # reddish (invaders)
            else:
                agent.color = np.array([0.35, 0.85, 0.35]) # greenish (hunters)
        for i, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.25, 0.25, 0.25])

        for agent in world.agents:
            self.spawn_position(agent, world)

        for i, landmark in enumerate(world.landmarks):
            self.landmark_spawn_position(landmark, world)

        world.steps = 0
        self.caught = 0
        self.rew_other = 0
        self.hunter_failed = False
        self.off_error = False

        # advance the main-process episode counter (used to throttle rendering
        # to every other episode)
        if self.process_index == 0:
            self.p0_num_ep += 1
        # if ScenePara.visualize:
        #     visualize(world.agents+world.landmarks)
            
        # done becomes True when world.steps >= world.max_steps_episode,
        # or on success / failure


    def benchmark_data(self, agent, world):
        # returns data for benchmarking purposes
        if agent.IsInvader:
            collisions = 0
            # NOTE(review): after make_world() runs, self.hunters is a list
            # (assigned there), so this call raises TypeError; presumably the
            # hunters() method below was intended — verify before relying on
            # benchmark_data.
            for a in self.hunters(world):
                if self.is_collision(a, agent):
                    collisions += 1
            return collisions
        else:
            return 0


    def is_collision(self, agent1, agent2):
        """True when the two agents' bodies overlap (distance < sum of radii)."""
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        dist_min = agent1.size + agent2.size
        return True if dist < dist_min else False


    def is_dectective(self, agent1, agent2):
        """True when the two agents are close enough to mutually detect each other.

        (Name kept as-is; "dectective" appears to be a typo for detectable.)
        """
        delta_pos = agent1.state.p_pos - agent2.state.p_pos
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        #dist_min = agent1.size + agent2.size
        return True if dist < ScenePara.distance_dectection else False

    # return, in order, all agents that are not invaders
    # NOTE(review): shadowed by the self.hunters list assigned in make_world()
    def hunters(self, world):
        return [agent for agent in world.agents if not agent.IsInvader]

    # return, in order, all adversarial (invader) agents
    # NOTE(review): shadowed by the self.invaders list assigned in make_world()
    def invaders(self, world):
        return [agent for agent in world.agents if agent.IsInvader]

    def uc_render(self, hunters, invaders, landmarks):
        """Push current entity positions (converted to meters) to the UDP plot client."""
        for a,A in enumerate(invaders):
            x = A.state.p_pos[0]
            y = A.state.p_pos[1]
            self.uc.v2d('Type1--%d'%a,ToMeter(x),ToMeter(y),0) # function v2d(name, xpos, ypos, dir)

        for a,A in enumerate(hunters):
            x = A.state.p_pos[0]
            y = A.state.p_pos[1]
            self.uc.v2d('Type2--%d'%a,ToMeter(x),ToMeter(y),0) # function v2d(name, xpos, ypos, dir)

        for a,A in enumerate(landmarks):
            x = A.state.p_pos[0]
            y = A.state.p_pos[1]
            self.uc.v2d('Type3--%d'%a,ToMeter(x),ToMeter(y),0) # function v2d(name, xpos, ypos, dir)

        return

    def reward_forall(self, agents ,world):
        """Compute rewards for all agents in one pass.

        Returns invader rewards followed by hunter rewards, matching the
        agent order from make_world (invaders first).  Side effects: slows /
        kills / respawns invaders, sets self.hunter_failed and self.off_error,
        and drives visualization on the main process every other episode.
        """
        # entity lists cached by make_world (the list attributes, not the methods)
        hunters  = self.hunters
        invaders = self.invaders
        landmars = self.landmarks
        # render only on the main process and only on odd-numbered episodes
        if self.process_index == 0 and self.p0_num_ep%2 == 1:
            self.uc_render(hunters, invaders, landmars)

        # reward accumulators
        # NOTE(review): these start as int64 arrays; when Penalty_distance is 0
        # (an int) they stay integer, and the later `+= 0.5` silently truncates
        # to 0 on item assignment — confirm whether the 0.5 tracking reward is
        # intended to survive in that case.
        hunter_reward  = np.array([0] * self.num_hunters )
        invader_reward = np.array([0] * self.num_invaders)
        # distance matrix hunter x invader
        distance = np.zeros(shape=(self.num_hunters, self.num_invaders), dtype=np.float32)
        # distance matrix invader x landmark
        distance_landmark = np.zeros(shape=(self.num_invaders, self.num_landmarks), dtype=np.float32)
        for b,B in enumerate(invaders):
            for a,A in enumerate(hunters):
                distance[a,b] = Norm( A.state.p_pos - B.state.p_pos )
            for c,C in enumerate(landmars):
                distance_landmark[b,c] = Norm( B.state.p_pos - C.state.p_pos )



        # hunter rewards:
        # <1> penalty that grows as any invader closes in on its nearest landmark
        #     (active only within Unit(20) == 1.0 world unit of a landmark)
        min_min_distance = np.min(distance_landmark)
        Penalty_distance = 0.1 / (min_min_distance+0.05) if min_min_distance < Unit(m=20) else 0

        # penalty applied to every hunter
        hunter_reward = hunter_reward - Penalty_distance

        # invaders get the opposite sign
        invader_reward = invader_reward + Penalty_distance
        
        # <2> hunters are rewarded for intercepting invaders; first reset each
        #     invader's per-step tracking state and restore its full speed
        for invader in invaders: 
            invader.tracked_by = []
            invader.max_speed = ScenePara.Invader_MaxSpeed

        for i,hunter in enumerate(hunters):
            hunter_index = i
            # skip dead hunters
            if not hunter.live:
                continue
            # each live hunter slows every invader within its affect range
            for j,invader in enumerate(invaders):
                distance_i_j = distance[i,j]

                # filter dead ones
                if not invader.live:
                    continue
                # filter far away ones
                if distance_i_j > ScenePara.hunter_affect_range:
                    continue
                # now distance_i_j <= ScenePara.hunter_affect_range
                # draw the hunter->invader link on the visualizer
                if self.process_index == 0 and self.p0_num_ep%2 == 1:
                    x1 = hunter.state.p_pos[0]
                    y1 = hunter.state.p_pos[1]
                    x2 = invader.state.p_pos[0]
                    y2 = invader.state.p_pos[1]
                    self.uc.v2L('TypeL--%d+%d'%(i,j), ToMeter(x1), ToMeter(y1), ToMeter(x2), ToMeter(y2), "D=%.1f"%ToMeter(distance_i_j) )   # v2L(name, x1, y1, x2,  y2, str)

                invader.max_speed -= ScenePara.hunter_speed_pressure
                invader.tracked_by.append(hunter_index)

        # an invader tracked by enough hunters counts as intercepted
        for invader_index, invader in enumerate(invaders):
            for hunter_index in invader.tracked_by:
                # basic reward
                hunter_reward[hunter_index] += 0.5

            if len(invader.tracked_by) >= ScenePara.intercept_hunter_needed:
                invader.live = False
                # (disabled) extra reward scaled by how far from the landmark
                # the invader was when intercepted
                #distance_list = distance_landmark[invader_index,:]
                #min_distance = np.min(distance_list)
                #Reward_distance = min_distance
                for hunter_index in invader.tracked_by:
                    # basic reward
                    hunter_reward[hunter_index] += 5
                    # additional reward
                    #hunter_reward[hunter_index] += Reward_distance

                invader_reward[invader_index] -= 5
                # return invader to starting point
                self.invader_revise(invader, world)

        # ----- an invader touching a landmark means immediate hunter failure
        #for invader_index,invader in enumerate(invaders):
        #    if invader.live == False:
        #        continue
        #    distance_list = distance_landmark[invader_index,:]
        #    min_distance = np.min(distance_list)
        min_distance = np.min(distance_landmark)
        #if self.AnyInvaderReachLanderMark(invaders,landmars):
        if min_distance <= ScenePara.Invader_Kill_Range:
            hunter_reward = hunter_reward - 20
            invader_reward = invader_reward + 40
            self.hunter_failed = True
        
        # every invader is abnormally far from every landmark -> flag the
        # episode as off-track (done() will end it)
        if min_min_distance > ScenePara.arena_size * 1.2:
            self.off_error = True


        # update the visualizer's title/xlabel (main process, odd episodes)
        if self.process_index == 0 and self.p0_num_ep%2 == 1:
            self.uc.title( 'min-min-dist:%.1f, vs %.1f' \
                          %( 
                              ToMeter(min_min_distance), 
                              ToMeter(ScenePara.arena_size) * 1.2
                            )  
                        )
            self.uc.xlabel('episode:%d , step:%d'%(self.p0_num_ep,world.steps))
            self.uc.pause(0.05)

        return invader_reward.tolist() + hunter_reward.tolist()

    def done(self, agent, world):
        """Episode termination: time limit reached, invaders lost off-map,
        or hunters failed.

        Side effect: sets self.is_success (True unless hunters failed), which
        info() reports.
        """
        condition1 = world.steps >= world.max_steps_episode
        self.is_success = False if self.hunter_failed else True
        condition2 = self.off_error
        return condition1 or condition2 or self.hunter_failed

    def load_obs(self, obs, obs_pointer, fragment):
        """Write fragment into obs starting at obs_pointer.

        ndarrays occupy len(fragment) slots, anything else one slot.
        Returns (obs, advanced pointer).
        """
        L = len(fragment) if isinstance(fragment,np.ndarray) \
            else 1
        obs[obs_pointer:obs_pointer+L] = fragment
        return obs, obs_pointer+L

    def observation(self, agent, world):
        """Build this agent's observation vector."""
        # the observation is:
        # 1. fully observable
        # 2. self centered
        dimension = 2+2+1+4*(self.num_agents-1)+2*len(world.landmarks) 
        # 2 own velocity, 2 own position, 1 own liveness,
        # 2*M landmark offsets, 4*(N-1) other agents' position offset + velocity

        # allocate the observation buffer
        obs, obs_pointer = ( np.zeros(shape=(dimension,)),  0 )
        # fill it piece by piece
        obs, obs_pointer = self.load_obs(obs, obs_pointer, agent.state.p_vel)
        obs, obs_pointer = self.load_obs(obs, obs_pointer, agent.state.p_pos)


        for entity in world.landmarks:
            offset_ = entity.state.p_pos - agent.state.p_pos
            obs, obs_pointer = self.load_obs(obs, obs_pointer, offset_)

        for other_agent in world.agents:
            pos_offset = other_agent.state.p_pos - agent.state.p_pos
            vel_ = other_agent.state.p_vel

            if other_agent == agent:
                continue

            if other_agent.live == 0: # dead agents are observed as all-zero
                pos_offset = np.array([0, 0])
                vel_ = np.array([0, 0])
            obs, obs_pointer = self.load_obs(obs, obs_pointer, pos_offset)
            obs, obs_pointer = self.load_obs(obs, obs_pointer, vel_)
        
        obs, obs_pointer = self.load_obs(obs, obs_pointer, agent.live)  # agent.live: 0 dead, 1 alive

        # soft check only: prints "fail!" on NaN but does not stop execution
        assert_and_break(not np.isnan(obs).any())
        return obs



    def info(self, agent, world):
        """Per-step info dict.

        NOTE(review): self.is_success is first assigned in done(); if info()
        runs before any done() call this raises AttributeError — confirm the
        caller's ordering.
        """
        return {'hunter_failed': self.hunter_failed, 'world_steps': world.steps,'is_success': self.is_success}
                #'reward':self.joint_reward, 'dists':self.delta_dists.mean()}
 
